Conversation
| See also: [`Losses.focal_loss`](@ref) | ||
|
|
||
| """ | ||
| function logit_focal_loss(ŷ, y; γ=2.0f0, agg=mean, dims=1, ϵ=epseltype(ŷ)) |
There was a problem hiding this comment.
Some have crept in & need fixing, but there should not be greek-letter keywords. These can be gamma and eps?
Also, as written, γ=1.5 will cause Float32 input to be promoted to Float64. Can you avoid this somehow? Perhaps there should be a line like γ = gamma isa Integer ? gamma : convert(eltype(logpt), gamma). (Integer powers are faster.)
| 0.665241 0.665241 0.665241 0.665241 0.665241 | ||
|
|
||
| julia> Flux.logit_focal_loss(ŷ, y) ≈ 1.1277571935622628 | ||
| true |
There was a problem hiding this comment.
This example output doesn't match what's written.
More importantly, the example is an opportunity to show exactly how this relates to focal_loss, i.e. where the softmax goes. And perhaps (if you can think of a compact & clear way) the relation to crossentropy (or rather logitcrossentropy?) too.
There was a problem hiding this comment.
Yeah i still need to work through these tests, did not realize about the docstring tests until after already putting tests elsewhere :) Can do !
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
It would be quite nice to have focal loss from logits, for numerical stability. This PR implements that! We have logit versions of crossentropy et al, so this i think has precedence!
PR Checklist